import numpy as np

from centralized_verification.agents.decentralized_training.independent_agents.q_learner import QLearner
from centralized_verification.utils import convert_gym_space_to_q_shape


class TabularQLearner(QLearner):
    """Q-learner that keeps its action values in a dense NumPy table.

    The table has the leading dimensions returned by
    ``convert_gym_space_to_q_shape(self.obs_space)`` followed by a final axis
    of length ``self.num_actions``; every entry starts at zero.

    ``self.obs_space``, ``self.num_actions`` and ``self.discount`` are read
    here but not assigned in this class -- presumably they are set by
    ``QLearner.__init__`` (TODO confirm).
    """

    def __init__(self, *args, alpha_index=1, **kwargs):
        """Initialise the base learner and allocate a zero-filled Q-table.

        Args:
            *args: Positional arguments forwarded unchanged to ``QLearner``.
            alpha_index: Numerator of the decaying learning rate computed in
                ``update_q``.
            **kwargs: Keyword arguments forwarded unchanged to ``QLearner``.
        """
        super().__init__(*args, **kwargs)
        obs_idx = convert_gym_space_to_q_shape(self.obs_space)
        self.q_table = np.zeros((*obs_idx, self.num_actions))
        self.alpha_index = alpha_index

    @staticmethod
    def _obs_index(observation):
        """Return ``observation`` as a tuple usable as a multi-axis table index.

        Tuples are passed through untouched. Scalars, lists and arrays are
        converted to a tuple of their elements, because indexing an ndarray
        with a list/ndarray triggers NumPy advanced indexing (a copy selected
        along the first axis only) instead of addressing one cell per axis.
        """
        if isinstance(observation, tuple):
            return observation
        return tuple(np.atleast_1d(observation))

    def get_greedy_action(self, observation):
        """Return the index (as ``int``) of the highest-valued action for ``observation``."""
        return int(np.argmax(self.q_table[self._obs_index(observation)]))

    def update_q(self, obs, action, next_obs, rew, done, step_num, training_progress):
        """Apply one tabular Q-learning update for a single transition.

        Args:
            obs: Observation the action was taken in.
            action: Index of the action taken.
            next_obs: Observation reached after the action; only looked up
                when ``done`` is false.
            rew: Reward received for the transition.
            done: Whether the transition ended the episode; if so, no
                bootstrap value is added to the target.
            step_num: Step counter driving the learning-rate decay.
            training_progress: Unused by this learner.
        """
        # Learning rate decays as alpha_index / (0.1 * step_num + 1).
        # NOTE(review): with alpha_index > 1 this exceeds 1 for small
        # step_num -- confirm that is intended.
        alpha = self.alpha_index / (0.1 * step_num + 1)

        if done:
            # Terminal transition: no future value to bootstrap from.
            target_q_value = rew
        else:
            next_max_q_value = self.q_table[self._obs_index(next_obs)].max()
            target_q_value = rew + self.discount * next_max_q_value

        # Address the cell with one tuple index so the assignment always
        # writes into the real table rather than into a temporary copy.
        cell = (*self._obs_index(obs), action)
        self.q_table[cell] = (1 - alpha) * self.q_table[cell] + alpha * target_q_value

    def state_dict(self):
        """Return a snapshot of the learner state as ``{"q_table": ndarray}``.

        The table is copied so later updates do not alter the snapshot.
        """
        return {
            "q_table": self.q_table.copy()
        }

    def load_state_dict(self, state_dict):
        """Replace the Q-table with a copy of ``state_dict["q_table"]``.

        Copying keeps the caller's data untouched by subsequent training.
        """
        self.q_table = np.array(state_dict["q_table"])
